import numpy as np
import torch


# Mixup
def mixup(input, target, alpha, num_classes):
    """Blend a batch of inputs and their one-hot labels with a shuffled copy.

    A mixing coefficient ``lam`` is drawn from ``Beta(alpha, alpha)``. When
    ``alpha <= 0`` it is fixed to 1, which leaves the batch unmixed. Each
    sample is combined with a randomly permuted partner from the same batch:
    ``lam * x + (1 - lam) * x[perm]``. The one-hot labels are combined with
    the same ``lam`` and the same permutation.

    Args:
        input: Batch tensor whose first dimension is the batch size.
        target: Integer class index per sample. It is reshaped to ``(-1, 1)``
            and cast to int64, because ``scatter_`` requires int64 indices.
        alpha: Beta-distribution concentration. ``<= 0`` disables mixing.
        num_classes: Width of the one-hot label encoding.

    Returns:
        ``(mixed_x, mixed_y)``. ``mixed_x`` has the shape of ``input``.
        ``mixed_y`` is a float32 tensor of shape ``(batch_size, num_classes)``
        on ``target``'s device.
    """
    # Sample lam from NumPy's global RNG so existing np.random seeding still
    # controls it; keep the integer 1 when mixing is disabled.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1
    batch_size = input.shape[0]

    # Create the permutation on the input's device so indexing does not
    # trigger an implicit host-to-device copy.
    index = torch.randperm(batch_size, device=input.device)
    mixed_x = lam * input + (1 - lam) * input[index]

    # Allocate the one-hot buffer directly on the target's device instead of
    # building it on the CPU and copying it over.
    labels = torch.zeros(batch_size, num_classes, device=target.device)
    labels.scatter_(1, target.reshape(-1, 1).long(), 1)

    # input and target may live on different devices; use a matching index.
    mixed_y = lam * labels + (1 - lam) * labels[index.to(target.device)]
    return mixed_x, mixed_y


class BatchMixup:
    """Callable that applies ``mixup`` to a batch using fixed settings.

    Args:
        alpha: Beta-distribution concentration forwarded to ``mixup``. Values
            ``<= 0`` make ``mixup`` use a coefficient of 1, i.e. no mixing.
        num_classes: Number of classes used for the one-hot label encoding.
    """

    def __init__(self, alpha=1.0, num_classes=10):
        self.alpha = alpha
        self.num_classes = num_classes
        # Announce the configuration once, at construction time.
        print(f"Enable batchmixup with alpha {self.alpha}")

    def __call__(self, input, target):
        """Mix a batch of inputs together with their integer class labels.

        Args:
            input: Batch tensor; its first dimension is the batch size.
            target: Integer class index for each sample in ``input``.

        Returns:
            ``(mixed_x, mixed_y)`` as returned by ``mixup``: the blended
            inputs and the blended one-hot labels of shape
            ``(batch_size, num_classes)``.
        """
        return mixup(input, target, self.alpha, self.num_classes)
